# Copyright 2022 Huawei Technologies Co., Ltd
# 
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import glob
import os

import numpy as np
import tensorflow as tf
from src.losses.losses import get_loss
from src.losses.modules.adversarial import build_adversarial_loss
from src.losses.modules.perceptual import build_content_style_loss
from src.losses.modules.perceptual import build_perceptual_loss
from src.ops.edge import get_edges
from src.networks.register import RegisteredModel
from src.runner.common import name_space
from src.utils.logger import logger
from src.utils.utils import convert_to_dict


class Base(object, metaclass=RegisteredModel):
    """Base class for all the video processing models.

    Attributes:
        cfg: yacs CfgNode. Global configuration.
        model_name: str, model name.
        scale: int, output scale w.r.t input, e.g. EDVR output is 4x the scale
            of the input.
        num_net_input_frames: int, the total number of input frames to the
            network. For EDVR, the default is 5. See src.utils.default.
        num_net_output_frames: int, the number of output result frames by the
            network. Default is 1. Also see src.utils.default.
        num_data_lq_frames: int, the total number of lq frames in each case
            generated by the datasets. Typically it can be the same with
            ``num_net_input_frames``. But if one is to use temporal supervision
            and there will be multiple output frames. In this case, see the
            example below.
        num_data_gt_frames: int, the total number of the hq frames in each case
            produced by the dataset. Typically it is the same with
            ``num_net_output_frames``.
        input_color_space: str, the color space of the input frames. Default to
            ``rgb``.
        num_in_channels: int, number of the channels of the input frames.
            Corresponds to ``input_color_space``.
        is_train: boolean, whether the model is in training phase. Determined by
            the ``cfg.mode``.
        generative_model_scope: str, top scope name for the tensorflow graph.
            Default value is 'G'.
        output_dir: str, path to dump the summary.

    Example:
        The most confusion configuration may be the ``num_**_frames``. Here is
        an example of the basic scenario (multi-input frames and single center
        output frame):

        Frame 1 -----> |---------|
        Frame 2 -----> |         |
        Frame 3 -----> | network | -----> Frame 3' -----> Loss
        Frame 4 -----> |         |
        Frame 5 -----> |---------|

        In this case, we have ``num_net_input_frames=5`` and
        ``num_net_output_frames=1``. Also, since there is no temporal
        supervision for the outputs, ``num_data_lq_frames=num_net_input_frames=5``
        and ``num_data_gt_frames=num_net_output_frames=1``, which is the EDVR
        case.

        A second case is multi-input frames and multi-output frames:

        Frame 1 -----> |---------| -----> Frame 1' -----> |------|
        Frame 2 -----> |         | -----> Frame 2' -----> |      |
        Frame 3 -----> | network | -----> Frame 3' -----> | loss | -----> Loss
        Frame 4 -----> |         | -----> Frame 4' -----> |      |
        Frame 5 -----> |---------| -----> Frame 5' -----> |------|

        In this case, ``num_data_lq_frames=num_net_input_frames=5`` and
        ``num_data_gt_frames=num_net_output_frames=5``.

        Third case, multi-input frames, single center output frame and with
        temporal supervision:

        Frame 1 -----> |---------|
        Frame 2 -----> |         |
        Frame 3 -----> |         | -----> Frame 3' -----> |------|
        Frame 4 -----> |         | -----> Frame 4' -----> |      |
        Frame 5 -----> | network | -----> Frame 5' -----> | loss | -----> Loss
        Frame 6 -----> |         | -----> Frame 6' -----> |      |
        Frame 7 -----> |         | -----> Frame 7' -----> |------|
        Frame 8 -----> |         |
        Frame 9 -----> |---------|

        In the 3rd case, ``num_data_lq_frames=9``, ``num_net_input_frames=5``,
        ``num_data_gt_frames=5``, ``num_net_output_frames=1``, which satisfies:

        ``num_data_lq_frames = num_data_gt_frames + num_net_input_frames - num_net_output_frames``

        During inference, the network is still multi-input frames and
        single center output frame (same with the 1st case), while temporal
        loss can be applied to the network during training.

    Args:
        cfg: Configuration loaded from the *.yaml file.
    """
    def __init__(self, cfg):
        self.model_name = cfg.model.name
        self.scale = cfg.model.scale
        self.num_net_input_frames = cfg.model.num_net_input_frames
        self.num_net_output_frames = cfg.model.num_net_output_frames
        self.num_data_lq_frames = cfg.data.num_data_lq_frames
        self.num_data_gt_frames = cfg.data.num_data_gt_frames

        self.input_color_space = cfg.data.color_space
        # NOTE(review): hard-coded to 3 channels regardless of
        # ``input_color_space`` -- confirm single-channel spaces are unused.
        self.num_in_channels = 3
        self.is_train = cfg.mode == 'train'

        self.cfg = cfg
        self.lq = None   # input low-quality
        self.gt = None   # groundtruth
        self.hq = None   # output high-quality
        self.generative_model_scope = cfg.model.scope
        self.output_dir = cfg.train.output_dir

    @property
    def output_node(self):
        """Obtain the default output result of the network

        Return:
            A 4D [N, H, W, C] or 5D [N, T, H, W, C] tensorflow tensor.
        """
        return self.hq

    @property
    def input_node(self):
        """Obtain the default input node of the network.

        Return:
            A 4D [N, H, W, C] or 5D [N, T, H, W, C] tensor.
        """
        return self.lq

    def parameters(self, scope=''):
        """Obtain the model parameters given the scope.

        NOTE(review): this queries ``GLOBAL_VARIABLES`` (not
        ``TRAINABLE_VARIABLES``), so it also returns non-trainable variables
        such as batch-norm moving statistics -- confirm this is intended.

        Args:
            scope: str, the parameter scope. If is empty, return all the
                parameters in the top scope ``self.generative_model_scope``.

        Return:
            A list of parameter tensor in the given scope.
        """
        if scope == '':
            return tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                     self.generative_model_scope)
        else:
            return tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                     scope=scope)

    def calculate_content_loss(self, gt, hq):
        """Compute the pixel-wise content loss. The loss will be added to
        ``name_space.GeneratorLoss``.

        Args:
            gt: tensor, ground-truth of the training.
            hq: tensor, predictions of the network.

        Return:
            Scalar tensor, the weighted-reduced content loss.
        """
        eps = self.cfg.loss.content.loss_margin
        reduction = self.cfg.loss.content.loss_reduction

        loss = get_loss(self.cfg.loss.content.loss_type, gt, hq, eps=eps)
        # reduction strategy is adjusted to ascend platform, to keep the
        # gradient neither too large (in case of overflow) nor too small
        # (in case gradient vanishing because of the CUBE operator).
        if reduction == 'mean':
            loss = tf.reduce_sum(tf.reduce_mean(loss, axis=[1, 2]))
        elif reduction == 'sum':
            loss = tf.reduce_mean(tf.reduce_sum(loss, axis=[1, 2, 3]))
        else:
            raise NotImplementedError
        name_space.add_to_collection(
            name_space.GeneratorLoss,
            f'content {self.cfg.loss.content.loss_type}',
            loss)
        return loss

    def calculate_perceptual_loss(self, gt, hq):
        """Compute perceptual loss. The loss will be added to
        ``name_space.GeneratorLoss``.

        Args:
            gt: tensor, ground-truth of the training.
            hq: tensor, predictions of the network.

        Return:
            Scalar tensor, the weighted perceptual loss.
        """
        # perceptual loss will be weighted in build_perceptual_loss
        perceptual_config = convert_to_dict(self.cfg.loss.perceptual, [])
        perceptual_loss = build_perceptual_loss(gt, hq, perceptual_config)
        perceptual_loss = perceptual_loss * self.cfg.loss.perceptual.loss_weight

        name_space.add_to_collection(
            name_space.GeneratorLoss,
            'perceptual',
            perceptual_loss)
        return perceptual_loss

    def calculate_border_loss(self, gt, hq):
        """Compute edge loss. The loss will be added to ``name_space.GeneratorLoss``.

        Args:
            gt: tensor, ground-truth of the training.
            hq: tensor, predictions of the network.

        Return:
            Scalar tensor, the weighted edge loss.
        """
        # Fixed swapped variable names: the original bound ``hq_edge`` to the
        # edges of ``gt`` and vice versa. The computed quantity is unchanged:
        # loss(edges(gt), edges(hq)).
        gt_edge = get_edges(gt, method=self.cfg.loss.edge.method)
        hq_edge = get_edges(hq, method=self.cfg.loss.edge.method)
        edge_loss = get_loss(self.cfg.loss.content.loss_type, gt_edge, hq_edge)
        edge_loss = tf.reduce_sum(tf.reduce_mean(edge_loss, axis=[1, 2]))
        edge_loss = edge_loss * self.cfg.loss.edge.loss_weight

        name_space.add_to_collection(
            name_space.GeneratorLoss,
            'edge',
            edge_loss)

        return edge_loss

    def calculate_content_style_loss(self, gt, hq):
        """Compute style loss. The loss will be added to
        ``name_space.GeneratorLoss``.

        Args:
            gt: tensor, ground-truth of the training.
            hq: tensor, predictions of the network.

        Return:
            Scalar tensor, the weighted content-style loss.
        """
        # perceptual loss will be weighted in build_perceptual_loss
        perceptual_config = convert_to_dict(self.cfg.loss.perceptual, [])
        perceptual_loss = build_content_style_loss(gt, hq, perceptual_config)
        perceptual_loss = perceptual_loss * self.cfg.loss.perceptual.loss_weight

        name_space.add_to_collection(
            name_space.GeneratorLoss,
            'style',
            perceptual_loss)
        return perceptual_loss

    def calculate_adversarial_loss(self, gt, hq):
        """Compute adversarial loss. The loss will be added to
        ``name_space.GeneratorLoss``.

        Args:
            gt: tensor, ground-truth of the training.
            hq: tensor, predictions of the network.
        """
        # discriminator loss will be weighted and added to name_space in build_
        # adversarial_loss
        _ = build_adversarial_loss(gt, hq, self.cfg)

    def build_losses(self, *args, **kwargs):
        """Compute all the losses, including pixel-wise content loss (required),
        perceptual loss (if loss_weight > 0), edge loss (if loss_weight > 0),
        and adversarial loss (if loss_weight > 0).

        NOTE(review): ``calculate_content_style_loss`` exists but is never
        invoked here -- confirm whether style loss should be part of training.
        """
        # all losses should be added to name_space collections
        gt = tf.cast(self.gt, tf.float32)
        hq = tf.cast(self.hq, tf.float32)

        # Align prediction layout with the ground truth (e.g. 4D vs 5D).
        hq = tf.reshape(hq, gt.shape)

        _ = self.calculate_content_loss(gt, hq)
        if self.cfg.loss.edge.loss_weight > 0:
            _ = self.calculate_border_loss(gt, hq)

        if self.cfg.loss.perceptual.loss_weight > 0:
            _ = self.calculate_perceptual_loss(gt, hq)

        if self.cfg.loss.adversarial.loss_weight > 0:
            self.calculate_adversarial_loss(gt, hq)

    def build_metrics(self, *args, **kwargs):
        # Reserved for evaluation.
        pass

    def prepare_placeholder(self, size):
        """Prepare placeholder for **inference** phase, given the input size.

        Args:
            size: tuple/list, including [batchsize, (h, w)]

        Returns:
            None

        Raises:
            ValueError: if ``cfg.model.input_format_dimension`` is neither 4
                nor 5.
        """
        # Note: this function is only for non-train mode
        if self.lq is not None:
            # Placeholder already created. The original fell through here
            # (``pass``) and silently rebuilt a duplicate placeholder node;
            # return early instead to keep the existing input node.
            return
        b, spatial = size

        if self.cfg.model.input_format_dimension == 5:
            # Dynamic batch dimension when unspecified or negative.
            if b is None or b < 0:
                b = None
            self.lq = tf.placeholder(
                tf.float32,
                shape=[b,
                       self.num_net_input_frames,
                       *spatial,
                       self.num_in_channels],
                name='L_input')
        elif self.cfg.model.input_format_dimension == 4:
            # Mainly used for offline model inference for speeding up in the
            # AIPP on Ascend 310
            if b is None or b < 0:
                self.lq = tf.placeholder(
                    tf.float32,
                    shape=[None,
                           *spatial,
                           self.num_in_channels],
                    name='L_input')
            else:
                # Frames are folded into the batch dimension for 4D input.
                self.lq = tf.placeholder(
                    tf.float32,
                    shape=[b*self.num_net_input_frames,
                           *spatial,
                           self.num_in_channels],
                    name='L_input')
        else:
            raise ValueError(f'Input format dimension only support 4 or 5, '
                             f'but got {self.cfg.model.input_format_dimension}')

    def build_graph(self, dataloader=None, input_size=None, *args, **kwargs):
        """Build tensorflow graph, network building, loss calculation, metrics
        calculation, etc.

        Args:
            dataloader: tf.Datasets, in training or evaluation phase.
            input_size: tuple or list, [b, (h, w)], for inference and freeze
                phase.

        Returns:
            None
        """
        if self.cfg.mode in ['freeze', 'inference']:
            assert input_size is not None
            self.prepare_placeholder(input_size)
        elif self.cfg.mode in ['train', 'eval']:
            assert dataloader is not None
            self.lq, self.gt = dataloader
        else:
            raise NotImplementedError

        # Forward propagation
        self.hq = self.build_generator(self.lq)

        # NOTE(review): the summary entry named 'hq' stores ``self.gt`` --
        # confirm whether the ground truth is intentionally logged under that
        # key.
        if self.cfg.mode == 'train':
            name_space.add_to_collection(name_space.Summary, 'hq', self.gt)
            self.build_losses()
        elif self.cfg.mode == 'eval':
            name_space.add_to_collection(name_space.Summary, 'hq', self.gt)
            self.build_metrics()

        # Quantize outputs to uint8 image range for export/inference paths.
        if self.cfg.mode in ['eval', 'inference', 'freeze'] and \
                self.cfg.model.convert_output_to_uint8:
            self.hq = tf.cast(
                tf.round(
                    tf.clip_by_value(
                        self.hq * 255.,
                        0.,
                        255.)),
                tf.uint8
            )

        # Setup the output node for inference without network file.
        self.hq = tf.identity(self.hq, name='HQ_output')

        name_space.add_to_collections((name_space.Summary,
                                       name_space.InputField),
                                      'lq',
                                      self.lq)
        # NOTE(review): the output field is registered under the key 'gt' but
        # holds the prediction ``self.hq`` -- verify against the consumers of
        # ``name_space.OutputField``.
        name_space.add_to_collections((name_space.Summary,
                                       name_space.OutputField),
                                      'gt',
                                      self.hq)

    def build_generator(self, lq, *args, **kwargs):
        """Building the forward network. This is the interface every derived
        network class should implement.

        Args:
            lq: tensor, input frames. 4D or 5D tensor.

        Returns:
            None

        Raises:
            NotImplementedError: always, on the base class.
        """
        raise NotImplementedError

    def dump_summary(self, step, summary_dict):
        """Function to visualize the intermediate training status.
        In case, tensorboard is not available, one can use this function to
        check the intermediate training or evaluation results.

        Args:
            step: int, training step
            summary_dict: dict, contains all the results.

        Returns:
            None
        """
        pass
